# Tested with: Python 3.10, numpy 1.26, torch 2.2
# ------------------------------------------------------------
from __future__ import annotations
import math, random, itertools, copy, os, enum
from dataclasses import dataclass
from typing import Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import wandb
# ------------------------------------------------------------
# 0.  Utility
# ------------------------------------------------------------
# Device the policies are moved to; set to 'cuda' if desired.
DEVICE = torch.device("cpu")
# Short alias for the tensor constructor.
TENSOR = torch.tensor

def uniform_unit_sphere(d: int) -> torch.Tensor:
    """Return a random unit vector in R^d: a standard Gaussian draw divided by its L2 norm."""
    gaussian = torch.randn(d)
    return gaussian / gaussian.norm(p=2)

def block_rademacher(d: int, coords: List[int]) -> torch.Tensor:
    """Sparse vector with ±1/√K on coords, 0 else (K = len(coords)).

    Args:
        d: length of the returned vector.
        coords: indices that receive a random sign; one draw from the global
            `random` module is consumed per entry, in order.

    Returns:
        A float tensor of shape (d,).  When `coords` is empty the all-zero
        vector is returned (previously 0/0 produced an all-NaN vector).
    """
    v = torch.zeros(d)
    k = len(coords)
    if k == 0:
        # Nothing to fill, and dividing by sqrt(0) would turn every entry into NaN.
        return v
    for i in coords:
        v[i] = 1. if random.random() < 0.5 else -1.
    v /= math.sqrt(k)
    return v

# ------------------------------------------------------------
# 1.  GridWorld environment
# ------------------------------------------------------------
@dataclass
class StepT:
    """Outcome of a single environment step."""
    s: int        # tabular id of the state occupied after the step
    r: float      # reward collected on that step
    done: bool    # True once the episode is over

class GridWorld:
    """
    Windy N×N grid world (N = 20) with a fixed horizon of H = 20 steps.

      • reset(): every cell independently receives an N(0,1) reward with
        probability 0.5 and reward 0 otherwise.  The generator is re-created
        from the same seed on every reset, so one environment instance sees
        the same reward map in every episode.
      • Agent starts at (row=2,col=2) – 0-indexed
      • 4 actions: 0=up,1=down,2=left,3=right
      • With prob `wind` (default 0.4) the chosen action is *reversed*; this
        draw comes from the global `random` module, not the seeded generator.
      • A move that would leave the grid is ignored (the agent stays put).
        The reward returned is always that of the cell occupied after the step.

    NOTE(review): the start cell (2,2) is not the centre of a 20×20 grid –
    confirm whether that is intended.
    """
    N = 20
    H = 20
    ACTIONS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
    REVERSE = {0: 1, 1: 0, 2: 3, 3: 2}

    def __init__(self, wind=0.4, seed=43):
        self.wind = wind
        self.seed = seed
        self.reset()

    def reset(self):
        """Rebuild the reward map, rewind the clock and position; return the start state id."""
        # Re-seeding here is what makes the reward map identical across episodes.
        self.rng = np.random.RandomState(self.seed)
        self.t = 0

        side = self.N
        has_reward = self.rng.rand(side, side) < 0.5
        rewards = np.zeros((side, side))
        rewards[has_reward] = self.rng.randn(int(has_reward.sum()))
        self.grid_r = rewards

        self.pos = (2, 2)
        return self._obs()

    def _obs(self) -> int:
        """Row-major tabular state id, in [0, N*N - 1]."""
        row, col = self.pos
        return row * self.N + col

    def step(self, a: int) -> StepT:
        """Advance one time step with action `a`; returns (next state id, reward, done)."""
        self.t += 1
        # wind: flip the action to its opposite with probability `wind`
        if random.random() < self.wind:
            a = self.REVERSE[a]
        d_row, d_col = self.ACTIONS[a]
        row, col = self.pos
        target = (row + d_row, col + d_col)
        inside = 0 <= target[0] < self.N and 0 <= target[1] < self.N
        if inside:
            self.pos = target
        return StepT(self._obs(), self.grid_r[self.pos], self.t >= self.H)

# ------------------------------------------------------------
# 2.  Policy: tabular soft-max θ∈ℝ^{S×A}
# ------------------------------------------------------------
class TabularSoftmaxPolicy(nn.Module):
    """Tabular soft-max policy: one row of action logits per state id.

    NOTE(review): `S` is accepted but never used – the table always has
    400 rows (GridWorld emits ids up to N*N-1 = 399).  Making the row count
    follow `S` would require every caller to stop passing the default 25.
    """

    def __init__(self, S=25, A=4):
        super().__init__()
        logits = torch.zeros(400, A)       # all-zero logits -> uniform policy
        self.theta = nn.Parameter(logits)

    def forward(self, s: torch.LongTensor) -> torch.Tensor:
        """Action probabilities for the state id(s) in `s` (soft-max over the last dim)."""
        return self.theta[s].softmax(dim=-1)

    def act(self, s: int) -> int:
        """Sample one action for state `s` using NumPy's global RNG."""
        state = torch.LongTensor([s])
        probs = self.forward(state)[0].detach().cpu().numpy()
        n_actions = len(probs)
        return int(np.random.choice(n_actions, p=probs))

# ------------------------------------------------------------
# 3.  Simulated human feedback
# ------------------------------------------------------------
class PrefModel(enum.Enum):
    """Which link function turns a return gap into a preference probability."""
    BT = 'BradleyTerry'       # logistic link
    WEIBULL = 'Weibull'       # exp(-exp(-x)) link

class HumanPanel:
    """Simulated panel of M annotators who compare two trajectory returns.

    Args:
        model: link function used to turn the return gap into a probability.
        M: number of independent votes drawn per query.
    """
    def __init__(self, model:PrefModel, M:int):
        self.model = model
        self.M = M

    # ----- link functions σ ------------------------------------------------
    @staticmethod
    def _sigma_bt(x):           # logistic
        """Logistic link 1/(1+exp(-x))."""
        # For very negative x, exp(-x) overflows to inf and the quotient is
        # the correct limit 0; only the overflow warning is suppressed.
        with np.errstate(over='ignore'):
            return 1./(1.+np.exp(-x))

    @staticmethod
    def _sigma_weibull(x):
        """Link exp(-exp(-x)); not symmetric around 0 (value at 0 is exp(-1))."""
        try:
            return math.exp(-math.exp(-x))
        except OverflowError:
            # exp(-x) is beyond float range for x below about -709; the
            # limit of exp(-exp(-x)) there is 0.
            return 0.0
    # ---------------------------------------------------------------------

    def prob_prefer(self, r1:float, r0:float)->float:
        """Probability that return `r1` is preferred over return `r0`."""
        x = r1-r0
        if self.model==PrefModel.BT:
            return self._sigma_bt(x)
        else:
            return self._sigma_weibull(x)

    def query(self, R1:float, R0:float)->Tuple[int,float]:
        """
        Return majority vote (0/1) *and* empirical p̂ used in algorithms.
        We sample M independent Bernoulli draws from NumPy's global RNG;
        a tie counts as a vote for R1 (p̂ >= 0.5).
        """
        p = self.prob_prefer(R1,R0)
        votes = np.random.rand(self.M) < p
        phat = votes.mean()
        majority = int(phat>=0.5)
        return majority, phat

# ------------------------------------------------------------
# 4.  Trajectory generators & helpers
# ------------------------------------------------------------
def roll_out(env:GridWorld, pol:TabularSoftmaxPolicy):
    """Run one episode of `pol` in `env` (which is reset first).

    Returns:
        (transitions, total) where transitions is a list of
        (state, action, reward) tuples – state being the one the action was
        chosen in – and total is the sum of the rewards.
    """
    state = env.reset()
    transitions = []
    total = 0.
    finished = False
    while not finished:
        action = pol.act(state)
        outcome = env.step(action)
        transitions.append((state, action, outcome.r))
        total += outcome.r
        state, finished = outcome.s, outcome.done
    return transitions, total

def sample_pairs(policy_a, policy_b, N):
    """Lazily yield N tuples (tau0, R0, tau1, R1).

    tau0/R0 come from a roll-out of `policy_a`, tau1/R1 from a roll-out of
    `policy_b`; both use one shared, freshly constructed GridWorld.
    """
    shared_env = GridWorld()
    for _ in range(N):
        first = roll_out(shared_env, policy_a)
        second = roll_out(shared_env, policy_b)
        yield (*first, *second)

# ------------------------------------------------------------
# 5.  ZPG (Alg-1)
# ------------------------------------------------------------
@dataclass
class ZPGConfig:
    """Hyper-parameters for the zeroth-order preference-gradient runners."""
    T: int = 1000                      # number of outer iterations
    N: int = 1000                      # trajectory pairs sampled per iteration
    M: int = 1000                      # votes drawn by the panel per query
    μ: float = 0.1                     # scale of the parameter perturbation
    α: float = 0.05                    # step size of the parameter update
    link: PrefModel = PrefModel.BT     # preference link function
    trim: float = 1e-2                 # p̂ is clipped to [trim, 1-trim] before inverting the link

class ZPG:
    """
    Zeroth-order policy gradient driven by preference feedback (Alg-1).

    Each iteration perturbs the flattened policy table θ along a random unit
    direction v, rolls out N (current, perturbed) trajectory pairs, asks the
    panel for the empirical preference rate p̂ of each pair, maps p̂ through
    the inverse link and averages.  θ is then moved by a step of length α
    along the normalised gradient estimate.

    Args:
        cfg: hyper-parameters (see ZPGConfig).
        S, A: forwarded to TabularSoftmaxPolicy.
    """
    def __init__(self, cfg:ZPGConfig, S=25,A=4):
        self.cfg = cfg
        self.policy = TabularSoftmaxPolicy(S,A).to(DEVICE)
        self.d = self.policy.theta.numel()
        self.panel = HumanPanel(cfg.link, cfg.M)
        self.total_env_steps =0

    def _σinv(self, p:float)->float:
        """Inverse link: map a preference rate `p` back to a return gap.

        `p` is first clipped to [trim, 1-trim] so the logarithms stay finite.
        """
        p = np.clip(p, self.cfg.trim, 1.-self.cfg.trim)
        if self.cfg.link==PrefModel.BT:
            return math.log(p/(1-p))
        else:                         # inverse of exp(-exp(-x)), with small guards
            return -math.log(-math.log(p+1e-12)+1e-12)

    def run(self):
        """Run cfg.T iterations; return the list of per-iteration evaluation returns."""
        hist = []
        θview = self.policy.theta.view(-1)      # flat view sharing storage with theta
        for t in range(self.cfg.T):
            # torch.randn allocates on the CPU; move the direction to wherever
            # the parameters live so a non-CPU DEVICE does not fail on the add.
            v = uniform_unit_sphere(self.d).to(θview.device)
            # detach(): the perturbed copy is only evaluated, never
            # differentiated, so no autograd graph is needed.
            θ_plus = θview.detach() + self.cfg.μ*v
            # create perturbed policy object
            pol_plus = copy.deepcopy(self.policy)
            pol_plus.theta.data = θ_plus.view_as(self.policy.theta).clone()

            # collect preferences
            gap_est = 0.
            for tau0,R0,tau1,R1 in sample_pairs(self.policy, pol_plus, self.cfg.N):
                self.total_env_steps += len(tau0) + len(tau1)
                _, phat = self.panel.query(R1,R0)
                gap_est += self._σinv(phat)
            gap_est /= self.cfg.N

            # zeroth-order grad
            g = (self.d/self.cfg.μ) * gap_est * v
            # Normalise in torch: avoids the NumPy round-trip and the
            # copy-construct warning raised by torch.tensor(<tensor>).
            g_norm = g / (g.norm(p=2) + 1e-8)
            # SGD ascent (in place on the shared storage, outside autograd)
            θview.data += self.cfg.α * g_norm

            if (t+1)%1==0:
                # simple evaluation: one roll-out in a fresh environment
                _,R_eval = roll_out(GridWorld(), self.policy)
                hist.append(R_eval)
                print(f"[ZPG] iter {t+1:4d}  return={R_eval:6.3f}")
                # Log to wandb
                wandb.log({"return": R_eval, "ZPG_env_steps": self.total_env_steps}, step=self.total_env_steps)
        return hist

# ------------------------------------------------------------
# 6.  ZBCPG (Alg-2)
# ------------------------------------------------------------
@dataclass
class ZBCPGConfig(ZPGConfig):
    """ZPGConfig extended with the block size used by ZBCPG."""
    K: int = 20    # number of coordinates perturbed per iteration

class ZBCPG(ZPG):
    """
    Zeroth-order block-coordinate policy gradient (Alg-2).

    Same loop as ZPG, except that the perturbation direction is a sparse
    ±1/√K vector on K randomly chosen coordinates and the gradient estimate
    is applied without normalisation.
    """
    def __init__(self,cfg:ZBCPGConfig,S=25,A=4):
        # ZPG.__init__ already creates the policy, the panel and the step counter.
        super().__init__(cfg,S,A)
        self.cfg:ZBCPGConfig = cfg

    def run(self):
        """Run cfg.T iterations; return the list of per-iteration evaluation returns."""
        θview = self.policy.theta.view(-1)      # flat view sharing storage with theta
        d = self.d
        hist=[]
        for t in range(self.cfg.T):
            coords = random.sample(range(d), self.cfg.K)
            # block_rademacher allocates on the CPU; follow the parameter device.
            v = block_rademacher(d, coords).to(θview.device)
            # detach(): the perturbed copy is only evaluated, never differentiated.
            θ_plus = θview.detach() + self.cfg.μ*v
            pol_plus = copy.deepcopy(self.policy)
            pol_plus.theta.data = θ_plus.view_as(self.policy.theta).clone()

            gap_est = 0.
            for tau0,R0,tau1,R1 in sample_pairs(self.policy, pol_plus, self.cfg.N):
                _,phat = self.panel.query(R1,R0)
                self.total_env_steps += len(tau0) + len(tau1)
                gap_est += self._σinv(phat)

            gap_est /= self.cfg.N

            g = (d/(self.cfg.μ)) * gap_est * v
            # g is already a float32 tensor; wrapping it in torch.tensor()
            # only triggered a copy-construct warning.
            θview.data += self.cfg.α * g

            if (t+1)%1==0:
                _,R_eval = roll_out(GridWorld(), self.policy)
                hist.append(R_eval)
                print(f"[ZBCPG] iter {t+1:4d} return={R_eval:6.3f}")
                # Log to wandb.  "env_steps" now reports the environment steps
                # actually taken (it used to be the pair count (t+1)*N),
                # matching the x-axis and the other runners.
                wandb.log({"return": R_eval, "env_steps": self.total_env_steps}, step=self.total_env_steps)
        return hist

# ------------------------------------------------------------
# 7.  Reward-model + PPO baseline
# ------------------------------------------------------------
class TabularReward(nn.Module):
    """Tabular reward model: one learnable scalar per (state id, action).

    NOTE(review): `S` is accepted but never used – the table always has 400 rows.
    """

    def __init__(self, S=25, A=4):
        super().__init__()
        table = torch.zeros(400, A)
        self.w = nn.Parameter(table)

    def forward(self, s, a):
        """Look up the reward entry (or entries) at index (s, a)."""
        return self.w[s, a]

class RMPPO:
    """
    Reward-model + PPO baseline.

    Phase 1 (`train_rm`): roll out `traj_pairs` pairs with a fixed behaviour
    policy, label each pair with the panel's majority vote and fit a tabular
    reward model by logistic regression on the difference of summed
    per-step rewards.
    Phase 2 (`run`): optimise a fresh policy with a clipped-ratio PPO loss
    whose advantages come from the learned reward model.

    Args:
        S, A: forwarded to TabularSoftmaxPolicy / TabularReward.
        panel: preference oracle; required before `train_rm` / `run`.
        traj_pairs: number of labelled pairs collected for the reward model.
        ppo_iters: number of PPO iterations (one roll-out each).
        kl_beta: weight of the KL penalty term.
        γ, lam: discount and GAE λ used in the advantage recursion.
    """
    def __init__(self,S=25,A=4, panel:HumanPanel=None,
                 traj_pairs:int=500_000, ppo_iters=1000,
                 kl_beta=0.1, γ=1.0, lam=0.95):
        self.S=S; self.A=A
        self.behaviour = TabularSoftmaxPolicy(S,A)
        self.panel = panel
        self.traj_pairs = traj_pairs
        self.reward_net = TabularReward(S,A)
        self.γ=γ; self.lam=lam
        self.kl_beta=kl_beta
        self.ppo_iters=ppo_iters
        self.total_env_steps = 0

    @staticmethod
    def _traj_indices(traj):
        """Split a trajectory of (s, a, r) tuples into LongTensors of state ids and action ids."""
        states = torch.LongTensor([s for s,_,_ in traj])
        actions = torch.LongTensor([a for _,a,_ in traj])
        return states, actions

    # --- reward model training (log-likelihood MLE) -------------
    def train_rm(self):
        """Collect labelled trajectory pairs and fit `self.reward_net` for 5 epochs.

        Raises:
            ValueError: if no panel was supplied.
        """
        if self.panel is None:
            # Fail before spending time on roll-outs rather than on the first query.
            raise ValueError("RMPPO requires a HumanPanel (panel=None was given)")
        opt = optim.Adam(self.reward_net.parameters(), lr=3e-2)
        batch= []
        env=GridWorld()
        print("[RM] collecting data ...")
        for _ in range(self.traj_pairs):
            tau0,R0 = roll_out(env,self.behaviour)
            tau1,R1 = roll_out(env,self.behaviour)
            self.total_env_steps+=len(tau0)+len(tau1)
            majority, _ = self.panel.query(R1,R0)
            # Store index tensors once so they are not rebuilt every epoch;
            # the true rewards are not needed for reward-model training.
            batch.append((self._traj_indices(tau0), self._traj_indices(tau1), majority))
        print(f"[RM] data size = {len(batch)}")

        for epoch in range(5):
            random.shuffle(batch)
            for idx0,idx1,y in batch:
                opt.zero_grad()
                # predicted return gap R̂(τ1) − R̂(τ0), one vectorised lookup per trajectory
                margin = self.reward_net(*idx1).sum() - self.reward_net(*idx0).sum()
                loss = F.binary_cross_entropy_with_logits(
                    margin.view(1), torch.tensor([float(y)],dtype=torch.float32))
                loss.backward(); opt.step()
            print(f"[RM] epoch {epoch+1} done.")
    # -------------------------------------------------------------

    # --- PPO phase ----------------------------------------------
    def run(self):
        """Train the reward model, then run PPO; return the per-iteration evaluation returns."""
        self.train_rm()
        pol = TabularSoftmaxPolicy(self.S,self.A)
        old_pol = copy.deepcopy(pol)
        opt = optim.Adam(pol.parameters(), lr=2e-2)
        env = GridWorld()
        hist = []
        for it in range(self.ppo_iters):
            # collect rollout (true rewards are discarded)
            traj, _ = roll_out(env, pol)
            T = len(traj)
            self.total_env_steps += T
            s,a,_ = zip(*traj)
            s = torch.LongTensor(s)
            a = torch.LongTensor(a)
            # pseudo rewards from RM
            with torch.no_grad():
                r_hat = self.reward_net.forward(s,a).cpu().numpy()
            # GAE advantages
            adv = np.zeros(T); lastgaelam=0; v=0
            for t in reversed(range(T)):
                delta = r_hat[t] + self.γ*v - v  # v=0 since no value fn
                lastgaelam = delta + self.γ*self.lam*lastgaelam
                adv[t] = lastgaelam
            adv = torch.tensor(adv,dtype=torch.float32)

            # PPO loss
            opt.zero_grad()
            logp = F.log_softmax(pol.theta[s],dim=-1).gather(1,a.view(-1,1)).squeeze()
            with torch.no_grad():
                logp_old = F.log_softmax(old_pol.theta[s],dim=-1).gather(1,a.view(-1,1)).squeeze()
            ratio = torch.exp(logp-logp_old)
            clip = torch.clamp(ratio, 0.8,1.2)
            ppo_loss = -(torch.min(ratio*adv, clip*adv)).mean()
            kl = (torch.exp(logp_old)*(logp_old-logp)).mean()
            loss = ppo_loss + self.kl_beta*kl
            loss.backward(); opt.step()
            # the "old" policy is re-synchronised after every single update
            old_pol.load_state_dict(pol.state_dict())

            if (it+1)%1==0:
                _,R_eval = roll_out(GridWorld(), pol)
                hist.append(R_eval)
                print(f"[PPO] iter {it+1:4d} return={R_eval:6.3f}")
                # Log to wandb
                wandb.log({"return": R_eval, "RM_env_steps": self.total_env_steps}, step=self.total_env_steps)
        return hist
# ------------------------------------------------------------
# 8.  DFA & Online-DFA (simplified)
# ------------------------------------------------------------
class DFA:
    """
    DFA / Online-DFA (simplified).

    Each iteration samples `N_pairs` trajectory pairs – both trajectories
    come from the current policy – queries the panel for the empirical
    preference rate p̂ and takes one Adam step on the binary cross-entropy
    between the logit β·(log π(τ1) − log π(τ0)) and the soft label p̂.

    Args:
        online: only selects the tag used in the progress print-out.
        beta: scale applied to the log-probability difference.
        S, A: forwarded to TabularSoftmaxPolicy.
        panel: preference oracle used by `run`.
        N_pairs: pairs sampled per iteration.
        iters: number of iterations.

    `ref` (a copy of the initial policy) and `env` are stored but are not
    used by `run`.
    """
    def __init__(self,online=False,beta=0.000001,S=25,A=4,
                 panel:HumanPanel=None,N_pairs=1000,iters=1000):
        self.online=online
        self.beta=beta; self.N=N_pairs; self.iters=iters
        self.panel=panel
        self.pol = TabularSoftmaxPolicy(S,A)
        self.ref = copy.deepcopy(self.pol)     # π₀
        self.env = GridWorld()
        self.total_env_steps=0

    def run(self):
        """Run `iters` iterations; return the list of per-iteration evaluation returns."""
        opt = optim.Adam(self.pol.parameters(), lr=3e-2)
        hist = []
        for t in range(self.iters):
            batch=[]
            for tau0,R0,tau1,R1 in sample_pairs(self.pol,self.pol,self.N):
                self.total_env_steps+=len(tau0)+len(tau1)
                _,phat = self.panel.query(R1,R0)
                # Store just the state-action pairs without rewards
                tau0_sa = [(s,a) for s,a,_ in tau0]
                tau1_sa = [(s,a) for s,a,_ in tau1]
                batch.append((tau0_sa,tau1_sa,phat))
            opt.zero_grad()
            loss=0.
            for tau0,tau1,p in batch:
                # log-likelihood under DFA surrogate (token-level simpl.)
                logit = self.beta*(self._logp_traj(tau1) - self._logp_traj(tau0))
                loss += F.binary_cross_entropy_with_logits(
                    logit.view(1), torch.tensor([p],dtype=torch.float32))
            loss = loss/len(batch)
            loss.backward(); opt.step()

            if (t+1)%1==0:
                _,R_eval = roll_out(GridWorld(), self.pol)
                hist.append(R_eval)
                tag = "oDPO" if self.online else "DFA"
                print(f"[{tag}] iter {t+1:4d} return={R_eval:6.3f}")
                # Log to wandb
                wandb.log({"return": R_eval, "DPO_env_steps":self.total_env_steps }, step=self.total_env_steps)
        return hist

    def _logp_traj(self,tau):
        """Sum of log π(a|s) over the (s,a) pairs of `tau` (0-dim tensor, differentiable).

        Uses log_softmax rather than log(softmax(...)) so that a probability
        underflowing to 0 cannot produce -inf and, in turn, a NaN logit.
        """
        states = torch.LongTensor([s for s,_ in tau])
        actions = torch.LongTensor([a for _,a in tau])
        logp = F.log_softmax(self.pol.theta[states],dim=-1)
        return logp.gather(1,actions.view(-1,1)).sum()

    def _prob_traj(self,tau):
        """Yield π(a|s) under the current policy for each (s,a) pair of `tau`."""
        net = self.pol
        for s,a in tau:
            probs = F.softmax(net.theta[s],dim=-1)
            yield probs[a]

    # def _kl(self):
    #     p = F.softmax(self.pol.theta,dim=-1)
    #     q = F.softmax(self.ref.theta,dim=-1)
    #     kl = (p*(p.log()-q.log())).sum()/25.
    #     return kl




class OraclePPO:
    """
    PPO trained directly on the environment's true rewards (no preference
    feedback).

    Args:
        S, A: forwarded to TabularSoftmaxPolicy.
        ppo_iters: number of iterations (one roll-out and one Adam step each).
        kl_beta: weight of the KL penalty term.
        γ, lam: discount and GAE λ used in the advantage recursion.
    """

    def __init__(self, S=25, A=4, ppo_iters=1000,
                 kl_beta=0.1, γ=1.0, lam=0.95):
        self.S, self.A = S, A
        self.policy = TabularSoftmaxPolicy(S, A)
        self.old_policy = copy.deepcopy(self.policy)
        self.γ, self.lam = γ, lam
        self.kl_beta = kl_beta
        self.ppo_iters = ppo_iters
        self.total_env_steps = 0

    def run(self):
        """Run `ppo_iters` iterations; return the list of per-iteration evaluation returns."""
        optimiser = optim.Adam(self.policy.parameters(), lr=2e-2)
        env = GridWorld()
        returns = []

        for it in range(self.ppo_iters):
            # one on-policy episode, keeping the true rewards
            episode, _ = roll_out(env, self.policy)
            horizon = len(episode)
            self.total_env_steps += horizon
            states, actions, rewards = zip(*episode)
            states = torch.LongTensor(states)
            actions = torch.LongTensor(actions)
            rewards = torch.tensor(rewards, dtype=torch.float32)

            # GAE recursion, walked backwards; there is no value function,
            # so the baseline is the constant 0
            adv = np.zeros(horizon)
            running = 0
            baseline = 0
            for k in range(horizon - 1, -1, -1):
                td = rewards[k] + self.γ * baseline - baseline
                running = td + self.γ * self.lam * running
                adv[k] = running
            adv = torch.tensor(adv, dtype=torch.float32)

            # clipped-ratio surrogate plus KL penalty
            optimiser.zero_grad()
            picked = actions.view(-1, 1)
            logp = F.log_softmax(self.policy.theta[states], dim=-1).gather(1, picked).squeeze()
            with torch.no_grad():
                logp_old = F.log_softmax(self.old_policy.theta[states], dim=-1).gather(1, picked).squeeze()

            ratio = torch.exp(logp - logp_old)
            clipped = torch.clamp(ratio, 0.8, 1.2)
            surrogate = torch.min(ratio * adv, clipped * adv)
            ppo_loss = -surrogate.mean()
            kl = (torch.exp(logp_old) * (logp_old - logp)).mean()
            loss = ppo_loss + self.kl_beta * kl

            loss.backward()
            optimiser.step()

            # the "old" policy is re-synchronised after every single update
            self.old_policy.load_state_dict(self.policy.state_dict())

            if (it + 1) % 1 == 0:
                _, R_eval = roll_out(GridWorld(), self.policy)
                returns.append(R_eval)
                print(f"[Oracle-PPO] iter {it + 1:4d} return={R_eval:6.3f}")
                # Log to wandb
                wandb.log({"return": R_eval, "OraclePPO_env_steps": self.total_env_steps},
                          step=self.total_env_steps)

        return returns




# ------------------------------------------------------------
# 9.  Main
# ------------------------------------------------------------
if __name__ == "__main__":
    # Seed the three global RNGs the experiment draws from.
    random.seed(3)
    np.random.seed(3)
    torch.manual_seed(3)

    # Scale of the DFA logit; also embedded in the wandb run name below.
    beta = 0.00001
    # Bradley-Terry panel with a single vote per query.
    panel_bt = HumanPanel(PrefModel.BT, M=1)

    # Initialize wandb
    # wandb.init(project="rlhf-zo-gridworld", name=f"comparison-SPO")
    wandb.init(project="rlhf-zo-gridworld", name=f"comparison20-DPO1-{beta}")
    # wandb.init(project="rlhf-zo-gridworld", name=f"comparison-RM1")

    # ---------- Oracle PPO (true reward) ----------------------------------
    # print("\n=== Running Oracle PPO (true reward) ===")
    # oracle_ppo = OraclePPO(ppo_iters=100000)
    # oracle_ppo.run()

    # # ---------- ZPG -------------------------------------------------------
    # cfg_zpg = ZPGConfig(T=1000000, N=10, M=1000, μ=0.1, α=0.01)
    # zpg = ZPG(cfg_zpg); zpg_hist = zpg.run()

    # # ---------- ZBCPG -----------------------------------------------------
    # cfg_zb = ZBCPGConfig(T=200000,N=1,M=1000,μ=0.1,α=0.05,K=20)
    # zb = ZBCPG(cfg_zb); zb_hist = zb.run()
    #
    # # ---------- PPO baseline ---------------------------------------------
    # rmppo = RMPPO(panel=panel_bt, traj_pairs=5000, ppo_iters=100000)
    # rmppo.run()

    # # Initialize wandb
    # wandb.init(project="rlhf-zo-gridworld", name="comparison-DFA")
    # # ---------- DFA -------------------------------------------------------
    dfa = DFA(
        online=False,
        panel=panel_bt,
        iters=100000,
        N_pairs=1,
        beta=beta,
    )
    dfa.run()

    # Close wandb run
    wandb.finish()